# In[1]
# -*- coding: utf-8 -*-
"""
Construct the dataset data.csv and findings for Tables 3, 4, 5, and S1.

"Evaluating (weighted) dynamic treatment effects by double machine learning"
by Hugo Bodory, Martin Huber, and Lukáš Lafférs.
This file constructs the dataset data.csv by importing ten SAS data files. This
file also creates the descriptive statistics shown in Tables 3, 4, 5, and
S1 of the paper and the online supplement. The names of the files imported for
reading are: impact.sas7bdat, rand_dat.sas7bdat, jc_tl.sas7bdat,
edtrn_tl.sas7bdat, fu12_raw.sas7bdat, base_raw.sas7bdat, baseline.sas7bdat,
key_vars.sas7bdat, other_tl.sas7bdat, and empl_tl.sas7bdat.
"""

# In[2]: libraries

import pandas as pd  # data bases
import numpy as np  # array operations
import re  # regular expressions
import pyreadstat  # pandas dataframes and metadata objects


# In[3]: importing and merging data

# import data and restore case sensitivity of variable names before merging
# (pyreadstat.read_sas7bdat() transforms strings to uppercase letters)


def _lowercase_columns_at(frame, positions):
    """Lower-case the column names of frame at the given positions."""
    names = frame.columns.tolist()
    for pos in positions:
        names[pos] = names[pos].lower()
    frame.columns = names


# files whose variable names are lower-cased completely
df_impact, impact = pyreadstat.read_sas7bdat('impact.sas7bdat')
df_impact = df_impact.rename(columns=str.lower)

df_rand_dat, rand_dat = pyreadstat.read_sas7bdat('rand_dat.sas7bdat')
df_rand_dat = df_rand_dat.rename(columns=str.lower)

df_fu12_raw, fu12_raw = pyreadstat.read_sas7bdat('fu12_raw.sas7bdat')
df_fu12_raw = df_fu12_raw.rename(columns=str.lower)

df_base_raw, base_raw = pyreadstat.read_sas7bdat('base_raw.sas7bdat')
df_base_raw = df_base_raw.rename(columns=str.lower)

df_key_vars, key_vars = pyreadstat.read_sas7bdat('key_vars.sas7bdat')
df_key_vars = df_key_vars.rename(columns=str.lower)

# files in which only the variable names at selected positions are lower-cased
df_jc_tl, jc_tl = pyreadstat.read_sas7bdat('jc_tl.sas7bdat')
_lowercase_columns_at(df_jc_tl, (0, -1))

df_edtrn_tl, edtrn_tl = pyreadstat.read_sas7bdat('edtrn_tl.sas7bdat')
_lowercase_columns_at(df_edtrn_tl, (0, 209))

df_baseline, baseline = pyreadstat.read_sas7bdat('baseline.sas7bdat')
_lowercase_columns_at(
    df_baseline,
    (0, 1, 2, 13, 14, 15, 17, 18, 19, 20, 21, 24, 27, 32, 35, 36, 41, 42,
     45, 49, 50, 53, 54, 65, 66, 67, 68, 103, 107, 108, 109, 110, 111,
     112, 113, 114, 124, 155, 156, 164, 165, 166, 167, 169, 170, 171, 197,
     198, 199, 200, 202, 203, 204, 205, 206, 207, 208, 209, 379))

df_other_tl, other_tl = pyreadstat.read_sas7bdat('other_tl.sas7bdat')
_lowercase_columns_at(df_other_tl, (0,))

df_empl_tl, empl_tl = pyreadstat.read_sas7bdat('empl_tl.sas7bdat')
_lowercase_columns_at(df_empl_tl, (0,))

# merging the datasets: left joins on the identifier mprid, starting from
# the impact file (from key_vars only the variable female is added)
df = df_impact
for right_frame in (df_rand_dat, df_jc_tl, df_edtrn_tl, df_fu12_raw,
                    df_base_raw, df_baseline, df_key_vars[['mprid', 'female']],
                    df_other_tl, df_empl_tl):
    df = df.merge(right_frame, how='left', on=['mprid'])

# variable names and labels (for a name occurring in several files, the
# label of the first file in the list below is kept)
frames_with_metadata = [
    (df_impact, impact), (df_rand_dat, rand_dat), (df_jc_tl, jc_tl),
    (df_edtrn_tl, edtrn_tl), (df_fu12_raw, fu12_raw), (df_base_raw, base_raw),
    (df_baseline, baseline), (df_key_vars, key_vars),
    (df_other_tl, other_tl), (df_empl_tl, empl_tl)]
df_name_label = pd.concat(
    [pd.DataFrame({'name': frame.columns, 'varlab': meta.column_labels})
     for frame, meta in frames_with_metadata])
df_name_label = df_name_label.drop_duplicates(subset='name', keep='first')

print('controls,treated in impact study of Job Corps program')
print(df.status.value_counts(dropna=False).sort_index()[:2])

# In[4]: participating in job corps
# JCHX: in job corps in respective weeks

liste1 = ['JCHX%d' % week for week in range(1, 53)]  # year 1
liste2 = ['JCHX%d' % week for week in range(53, 105)]  # year 2

# dictionary for mapping the string values of JCHX to treatments:
# empty string -> 0, codes containing 1, 2, or 3 -> 1, and the codes
# 'A', 'B', 'AB', 'X' -> -0.001
f = {'': 0}
f.update(dict.fromkeys(
    ['1', '12', '13', '1A', '1B', 'A1', 'B1', '2', '21', '23', '2A', '2B',
     'A2', 'B2', '3', '31', '32', '3A', '3B', 'A3', 'B3'], 1))
f.update(dict.fromkeys(['A', 'B', 'AB', 'X'], -0.001))


def _participation_sign(week_columns):
    """Return the sign (1, 0, -1) of the summed weekly JCHX codes."""
    total = np.zeros(len(df))
    for column in week_columns:
        total = total + df[column].map(f).to_numpy()
    # positive sums become 1, negative sums -1; zeros and nan's are kept
    return np.sign(total)


# year 1
df = pd.concat(
    [df, pd.DataFrame({'jchx_1_52': _participation_sign(liste1)})], axis=1)

# year 2
df = pd.concat(
    [df, pd.DataFrame({'jchx_53_104': _participation_sign(liste2)})], axis=1)

# In[5]: hours spent in programs
# JCAH: weekly hours of academic education in job corps
# JCVH: weekly hours of vocational training in job corps


def _summed_weekly_hours(prefix, weeks):
    """Sum the weekly variables prefix+week; a missing week counts -0.001."""
    total = np.zeros(len(df))
    for week in weeks:
        total = total + df[prefix+str(week)].fillna(-0.001).to_numpy()
    return total


# new columns: JCAH_1_52, JCAH_53_104, JCVH_1_52, JCVH_53_104
for item in ('JCAH', 'JCVH'):
    for suffix, weeks in (('_1_52', range(1, 53)),  # year 1
                          ('_53_104', range(53, 105))):  # year 2
        hours = _summed_weekly_hours(item, weeks)
        df = pd.concat([df, pd.DataFrame({item+suffix: hours})], axis=1)

# In[6]: treatments
# treat codes the treatment sequence of years 1 and 2:
#   0        status 1 with less than one (summed) hour of academic education
#            and of vocational training in both years
#   10*a+b   status 2, where a (b) is the state in year 1 (2):
#            1 = zero hours of both kinds,
#            2 = academic hours > 0 and >= vocational hours,
#            3 = vocational hours > 0 and > academic hours
#   nan      all remaining observations (e.g. negative sums, which stem from
#            the -0.001 placeholders for missing weekly hours)


def _treat_states(acad, voc):
    """Return the boolean masks of the yearly states keyed by their code."""
    return {1: (acad == 0) & (voc == 0),
            2: (acad > 0) & (voc >= 0) & (acad >= voc),
            3: (acad >= 0) & (voc > 0) & (acad < voc)}


# start from nan (np.nan instead of the alias np.NaN, which NumPy 2.0 removed)
temp_array = np.full(len(df), np.nan)
idx = (df.status == 1) & (df.JCAH_1_52 < 1) & (df.JCVH_1_52 < 1) & \
    (df.JCAH_53_104 < 1) & (df.JCVH_53_104 < 1)
temp_array[idx.to_numpy()] = 0

# the yearly states are mutually exclusive, so the order of assignment is
# irrelevant
states_year1 = _treat_states(df.JCAH_1_52, df.JCVH_1_52)
states_year2 = _treat_states(df.JCAH_53_104, df.JCVH_53_104)
for code_year1, mask_year1 in states_year1.items():
    for code_year2, mask_year2 in states_year2.items():
        idx = (df.status == 2) & mask_year1 & mask_year2
        temp_array[idx.to_numpy()] = 10*code_year1 + code_year2
df = pd.concat([df, pd.DataFrame(temp_array, columns=['treat'])], axis=1)

print('\n table 3: sequences of binary treatments')
print(df.treat.value_counts(dropna=False).sort_index())

# In[7]: treaty0, treaty1: treatment states in years 1+2


def _yearly_treatment_state(acad, voc):
    """
    Code the treatment state of one year from the summed weekly hours.

    input arguments acad: hours of academic education (series)
                    voc: hours of vocational training (series)
    return argument codes: 0 = status 1 with less than one hour of each kind,
                           1 = status 2 with zero hours of both kinds,
                           2 = status 2, academic > 0 and >= vocational,
                           3 = status 2, vocational > 0 and > academic,
                           nan otherwise (array)
    """
    # np.nan instead of the alias np.NaN, which NumPy 2.0 removed
    codes = np.full(len(df), np.nan)
    controls = df.status == 1
    treated = df.status == 2
    codes[(controls & (acad < 1) & (voc < 1)).to_numpy()] = 0
    codes[(treated & (acad == 0) & (voc == 0)).to_numpy()] = 1
    codes[(treated & (acad > 0) & (voc >= 0) & (acad >= voc)).to_numpy()] = 2
    codes[(treated & (acad >= 0) & (voc > 0) & (acad < voc)).to_numpy()] = 3
    return codes


# year 1
temp_array = _yearly_treatment_state(df.JCAH_1_52, df.JCVH_1_52)
df = pd.concat([df, pd.DataFrame(temp_array, columns=['treaty0'])], axis=1)

# year 2
temp_array = _yearly_treatment_state(df.JCAH_53_104, df.JCVH_53_104)
df = pd.concat([df, pd.DataFrame(temp_array, columns=['treaty1'])], axis=1)

# In[8]: treatment dummy for jobcorps, 0=randomized out, 1=randomized in

df['status'] = df['status'] - 1

# In[9]: outcomes in year 4

# emply4: sum of workq13 to workq16 with positive sums set to 1; the sum is
# nan only if all four variables are missing (min_count=1)
liste = ['workq13', 'workq14', 'workq15', 'workq16']
quarter_sum = df[liste].sum(axis='columns', min_count=1)
df['emply4'] = quarter_sum.mask(quarter_sum > 0, 1)

# In[10]: covariates

WEEKS = 52
MONTHS = 12

# transform to dummies
# (variables coded 1/2 are shifted to 0/1)
for name in ('WAY_STAY', 'howspoke'):
    df[name] = df[name] - 1
# (values strictly between zero and one are set to one)
for name in ('ANY_ED1', 'EVARRST1', 'YR_WORK1', 'GOTAFDC1', 'GOTOTHW1',
             'GOTFS1'):
    is_fraction = (df[name] > 0) & (df[name] < 1)
    df.loc[is_fraction, name] = 1

# replace values between zero and one by nan's
liste = ['burglary', 'robbery', 'assault', 'larceny', 'drugviol', 'othpers',
         'othmisc', 'SERCR_S1', 'SERCR_S2', 'SERCR_S3', 'SERCR_S4', 'SERCR_S5',
         'SERCR_S6', 'SERCR_S7', 'GUILTY2', 'wksjail', 'PENDING2', 'COPPLEA2',
         'SERCR_C1', 'SERCR_C2', 'SERCR_C3', 'SERCR_C4', 'SERCR_C5',
         'SERCR_C6', 'SERCR_C7', 'ASSLT_C2', 'ROB_C2', 'BURGL_C2', 'LARCNYC2',
         'DRVIOLC2', 'OTHPERC2', 'OTHMSCC2', 'EVJAIL2', 'PAROLE2']
for name in liste:
    is_fraction = (df[name] > 0) & (df[name] < 1)
    df.loc[is_fraction, name] = np.nan

# generate dummies for covariates based on weekly data
liste = ['JCHX', 'JCAXH', 'JCVOH', 'EDTNX', 'EDAXH', 'DGTRH', 'WORKH', 'CCEMP',
         'CCEDT', 'CCJC']
liste = df_baseline.columns[211:-1].tolist() + \
    [prefix+str(week) for prefix in liste for week in range(1, WEEKS+1)]


def _code_to_dummy(code):
    """
    Transform one value to zero, one, or nan.

    Strings: 1 if the concatenated digits form a positive number, 0 if they
    form zero or if the string is empty, nan if a non-empty string has no
    digits. Non-strings: 0 if equal to zero, nan otherwise.
    """
    try:
        return 1*(int(''.join(re.findall('\\d', code))) > 0)
    except (ValueError, TypeError):
        # ValueError: no digits in the string; TypeError: not a string
        if (code == '') or (code == 0):
            return 0
        return np.nan


# transform values from strings to zeros, ones, or nan's
x = {value for column in liste for value in df[column]}
code_to_dummy = {value: _code_to_dummy(value) for value in x}
for column in liste:
    df[column] = df[column].map(code_to_dummy)

# generate dummies for covariates based on monthly and weekly data
liste1 = ['UIH']
liste1 = [i+str(j) for i in liste1 for j in range(1, WEEKS+1)]
liste2 = ['AFDCH', 'FSH', 'SSIH', 'GAH', 'ANYPH', 'WICH']
liste2 = [i+str(j) for i in liste2 for j in range(1, MONTHS+1)]
liste = liste1+liste2


def _amount_to_dummy(value):
    """Map 0 to 0, values of at least 1 to 1, and everything else to nan."""
    if value == 0:
        return 0
    if value >= 1:
        return 1
    # missing values, but also unexpected values (negative or strictly
    # between zero and one): every value needs an entry, otherwise the
    # values and their dummies get out of step
    return np.nan


x = pd.Series([j for i in liste for j in df[i]]).unique()
y = [_amount_to_dummy(i) for i in x]
value_to_dummy = dict(zip(x, y))
for i in liste:
    df[i] = df[i].map(value_to_dummy)

# new levels, analogous to baseline variable marriage
# (the codes are first moved to 11..14, so that a recoded row cannot be hit
# by a later rule, and are shifted back by 10 afterwards)
# NOTE(review): values other than 1-6, 8, 9 (e.g. 7) are not recoded and end
# up reduced by 10 -- confirm that such values do not occur
for old_codes, new_code in (([1], 12), ([2, 3, 4], 14), ([5], 13), ([6], 11),
                            ([8, 9], np.nan)):
    df.loc[df['f_f23'].isin(old_codes), 'f_f23'] = new_code
df['f_f23'] = df['f_f23'] - 10

# assign nan's
for name, nan_codes in (('f_g78', [8, 9]), ('f_g80', [8]), ('f_g81', [8, 9]),
                        ('f_g82', [7, 8, 9])):
    df.loc[df[name].isin(nan_codes), name] = np.nan

# In[11]: dataframe (missings: -1)

# var_d: assignment indicator (status) and the constructed treatment variables
var_d = ['status', 'treat', 'treaty0', 'treaty1']
# var_y: outcome
var_y = ['emply4']
# var_x0: covariates of year 0 (class 3 in the variable statistics below);
# the explicit names are followed by the baseline columns from position 211
# up to (excluding) the last one, and by currjob and female
var_x0 = ['jcmsa', 'age', 'RACE_ETH', 'NTV_LANG', 'HH14', 'WELF_KID',
          'HGC_MOTH', 'M_WORK14', 'OCC_MOTH', 'HGC_FATH', 'F_WORK14',
          'OCC_FATH', 'marriage', 'haschld', 'proplive', 'PREGN_RA', 'old',
          'yng', 'othwith', 'nchld', 'ageparnt', 'NUMB_HH', 'R_HEAD', 'hhmemb',
          'HOUS_ARR', 'PAY_RENT', 'hgc', 'HS_D', 'GED_D', 'VOC_D', 'OTH_DEG',
          'inschool', 'ANY_ED1', 'N_ED_CAT', 'monined', 'reasleft', 'REC_ED',
          'TYPEED_R', 'NHRSED_R', 'REASED_R', 'numbjobs', 'evworkb',
          'YR_WORK1', 'EARN_YR', 'mosinjob', 'REC_JOB', 'OCC_R', 'HRSWK_JR',
          'hrwager', 'COOP_R', 'GOVPRG_R', 'leftjobr', 'rslftjr', 'MOS_AFDC',
          'MOS_OTHW', 'MOS_FS', 'GOT_ANYW', 'MOS_ANYW', 'GOTAFDC1', 'GOTOTHW1',
          'GOTFS1', 'HH_INC', 'PERS_INC', 'health', 'sick', 'typehlth',
          'timesick', 'EV_CIG', 'PY_CIG', 'EV_ALCHL', 'PY_ALCHL', 'EV_POT',
          'PY_POT', 'EV_COKE', 'PY_COKE', 'EV_CRACK', 'PY_CRACK', 'EV_HROIN',
          'PY_HROIN', 'EV_SPEED', 'PY_SPEED', 'EV_LSD', 'PY_LSD', 'EV_OTHDR',
          'PY_OTHDR', 'EV_INJCT', 'DRUG_TRT', 'MOUT_TRT', 'MOS_TRTR',
          'FRQ_CIG', 'FRQ_ALC', 'FRQ_POT', 'FRQ_COKE', 'FRQ_CRAC', 'FRQ_HERN',
          'FRQ_SPED', 'FRQ_LSD', 'FRQ_INJ', 'FRQ_OTH', 'narrcat', 'EVARRST1',
          'RC_ARRST', 'MARRCAT1', 'agearcat', 'burglary', 'robbery', 'assault',
          'larceny', 'drugviol', 'othpers', 'othmisc', 'SERCR_S1', 'SERCR_S2',
          'SERCR_S3', 'SERCR_S4', 'SERCR_S5', 'SERCR_S6', 'SERCR_S7',
          'N_GUILTY', 'GUILTY2', 'wksjail', 'PENDING2', 'COPPLEA2', 'SERCR_C1',
          'SERCR_C2', 'SERCR_C3', 'SERCR_C4', 'SERCR_C5', 'SERCR_C6',
          'SERCR_C7', 'ASSLT_C2', 'ROB_C2', 'BURGL_C2', 'LARCNYC2', 'DRVIOLC2',
          'OTHPERC2', 'OTHMSCC2', 'EVJAIL2', 'PAROLE2', 'HEAR_JC', 'FROM_OA',
          'KNEW_JC', 'INFO_JC', 'R_HOME', 'R_COMM', 'R_TRAIN', 'R_CRGOAL',
          'R_GETGED', 'R_NOWORK', 'R_OTHER', 'mostimpr', 'othimpr', 'E_MATH',
          'E_READ', 'E_ALONG', 'E_CONTRL', 'E_ESTEEM', 'E_SPCJOB',  'E_FRIEND',
          'knewcntr', 'imprcntr', 'knewjob', 'typejobb', 'EARN_CMP',
          'hadworry', 'typeworr', 'TALK_PAR', 'IMP_PAR', 'ENCR_PAR',
          'TALK_REL', 'IMP_REL', 'ENCR_REL', 'TALK_FRD', 'IMP_FRD', 'ENCR_FRD',
          'TALK_TCH', 'IMP_TCH', 'ENCR_TCH', 'TALK_CW', 'IMP_CW', 'ENCR_CW',
          'TALK_PRO', 'IMP_PRO', 'ENCR_PRO', 'TALK_CHL', 'IMP_CHL', 'ENCR_CHL',
          'TALK_ADL', 'IMP_ADL', 'ENCR_ADL', 'ENCR_JCR', 'howspoke',
          'telemode', 'placeipc', 'talkstay', 'WAY_STAY', 'rstaycat',
          'tstaycat', 'talkvstf', 'talktold', 'tradwant', 'chncetrd',
          'totalhrs', 'VSTF_CAT'] + \
    df_baseline.columns[211:-1].tolist()+['currjob', 'female']
# var_x1: covariates of year 1 (class 4 in the variable statistics below)
# variables observed per week (suffix 1..WEEKS)
weekly_prefixes = ['JCAXH', 'EDTNX', 'EDAXH', 'DGTRH', 'WORKH', 'CCEMP',
                   'CCEDT', 'CCJC', 'UIH', 'HWH', 'EARNH', 'EDTA', 'EDTV']
# variables observed per month (suffix 1..MONTHS)
monthly_prefixes = ['AFDCH', 'FSH', 'SSIH', 'GAH', 'ANYPH', 'WICH']
var_x1 = [prefix+str(week) for prefix in weekly_prefixes
          for week in range(1, WEEKS+1)]
var_x1 += [prefix+str(month) for prefix in monthly_prefixes
           for month in range(1, MONTHS+1)]
# further variables without a time suffix
var_x1 += [
    'f_f23', 'f_g78', 'f_g80', 'f_g81', 'f_g82', 'evarrq1', 'evarrq2',
    'evarrq3', 'evarrq4', 'narry1', 'narc1_1', 'narc1_2', 'narc1_3',
    'narc1_4', 'narc1_5', 'narc1_6', 'narc1_7', 'narc1_8', 'anycra12',
    'moneya12', 'numvic12', 'numcra12', 'carst12', 'burgd12', 'assld12',
    'robbd12', 'rippd12', 'nassld12', 'nburgd12', 'nrobbd12', 'nrippd12',
    'ncarst12', 'cig12', 'drink12', 'pot12', 'coke12', 'crack12', 'heroin12',
    'speed12', 'lsd12', 'inject12', 'othdrg12', 'fcig12', 'fpot12', 'fcoke12',
    'fcrack12', 'fhern12', 'fspeed12', 'flsd12', 'finjct12', 'fothdg12',
    'hard12', 'anydr12', 'fdrink12', 'health12', 'pe_prb12', 'whoursy1',
    'eparent1', 'egrand1', 'eotrel1', 'epnrel1', 'eunrel1', 'edaycar1',
    'eschool1', 'enrel1', 'evrel1']

# code missing values of all treatments, outcomes, and covariates as -1
liste = var_d+var_y+var_x0+var_x1
df[liste] = df[liste].fillna(value=-1)

# In[12]: variable statistics
# variable names
# variable descriptions
# data types: 1=dummy, 2=categorical, 3=numeric
# variable classes: 1=treatment, 2=outcome, 3(4)=covariate year 0(1)

# names of the assignment indicator (first entry of var_d) and of the
# covariates (x0, x1), each with its label from the imported files; names
# without an imported label keep a missing label (left join)
liste = [var_d[0], *var_x0, *var_x1]
df_name = pd.DataFrame({'name': liste})
df_vars1 = pd.merge(df_name, df_name_label, how='left', on='name')
# data types: 1=dummy, 2=categorical, 3=numeric
df_vars1['type'] = \
    [1,  2,  3,  2,  2,  2,  2,  3,  1,  2,  3,  1,  2,  2,  1,  2,  2,  3,
     3,  2,  3,  3,  3,  1,  2,  2,  1,  3,  1,  1,  1,  1,  1,  1,  2,  3,  2,
     2,  2,  3,  2,  3,  1,  1,  3,  3,  2,  2,  3,  3,  1,  1,  1,  2,  3,  3,
     3,  1,  3,  1,  1,  1,  2,  2,  1,  1,  2,  3,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  3,  2,  2,  2,  2,  2,
     2,  2,  2,  2,  2,  2,  2,  1,  2,  2,  2,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  3,  1,  3,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  2,  1,  1,  2,  1,  1,  1,  1,  1,  1,  1,
     2,  2,  1,  1,  1,  1,  1,  1,  1,  1,  2,  1,  2,  3,  1,  2,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  2,  1,  2,  2,  1,  1,  2,  2,  1,  1,  1,  2,  3,  2,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  3,  3,  3,  3,  3,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
     1,  1,  1,  2,  2,  1,  1,  1,  1,  1,  1,  1,  3,  3,  3,  3,  3,  3,  3,
     3,  3,  1,  3,  3,  3,  1,  1,  1,  1,  1,  3,  3,  3,  3,  3,  1,  1,  1,
     1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  1,  1,  2,
     2,  1,  3,  1,  1,  1,  1,  1,  1,  1,  1,  1]

# further variable names, descriptions, and data types
# (constructed treatments and outcome, which have no imported label)
df_vars2 = pd.DataFrame({
    'name': ['treat', 'treaty0', 'treaty1', 'emply4'],
    'varlab': [
        ('0:controls-no educ/train in job corps in yr 1+2; '
         '1:treated-no educ/train in job corps in yr 1/2; '
         '2:treated-academic educ in job corps in yr 1/2; '
         '3:treated-vocational train in job corps in yr 1/2'),
        ('0:controls-no educ/train in job corps in yr 1; '
         '1:treated-no educ/train in job corps in yr 1; '
         '2:treated-academic educ in job corps in yr 1; '
         '3:treated-vocational train in job corps in yr 1'),
        ('0:controls-no educ/train in job corps in yr 2; '
         '1:treated-no educ/train in job corps in yr 2; '
         '2:treated-academic educ in job corps in yr 2; '
         '3:treated-vocational train in job corps in yr 2'),
        'employed in year 4'],
    'type': [2, 2, 2, 1]})

# insert the four rows after the first row of df_vars1, which yields the
# order status, treat, treaty0, treaty1, emply4, covariates
first_row, remaining_rows = df_vars1.iloc[:1], df_vars1.iloc[1:]
df_vars = pd.concat([first_row, df_vars2, remaining_rows], ignore_index=True)

# variable classes: 1=treatment, 2=outcome, 3(4)=covariate year 0(1)
df_vars['class'] = 0
for class_code, names in enumerate((var_d, var_y, var_x0, var_x1), start=1):
    df_vars.loc[df_vars.name.isin(names), 'class'] = class_code

# missings (number of observations coded -1 per variable); row c of df_vars
# is filled with the count of the c-th name in liste_var
df_vars['miss'] = 0
liste_var = var_d+var_y+var_x0+var_x1
liste_miss = []  # names of the variables with at least one missing
for row, name in enumerate(liste_var):
    n_missing = (df[name] == -1).sum()
    df_vars.loc[row, 'miss'] = n_missing
    if n_missing > 0:
        liste_miss.append(name)

# compute means considering non-missing elements only
# (missing values were recoded to -1 above; x >= 0 masks all negative values
# before averaging, and the transpose puts the variables into the rows and
# the values of treat into the columns)
df_means = df.groupby(['treat'])[df_vars.name].apply(
    lambda x: x[x >= 0].mean()).transpose()
# NOTE(review): the eleven labels assume that treat takes exactly the values
# -1 (missing), 0, 11, 12, 13, 21, 22, 23, 31, 32, 33 -- confirm
df_means.columns = ['treat_miss', '0', '11', '12', '13', '21', '22', '23',
                    '31', '32', '33']
# the index of df_means is reset so that its rows are attached to the rows of
# df_vars by position; presumably both are in the order of df_vars.name --
# TODO confirm
df_vars_means = pd.concat([df_vars, df_means.reset_index(drop=True)],
                          axis='columns')

# In[13]: table S1 (description of regressors and means across treatments)

# rows 5 and following: the first five rows (status, treat, treaty0,
# treaty1, emply4) are dropped
table_s1 = df_vars_means.iloc[5:].round(2).reset_index(drop=True)
# covariate classes 3/4 become 0/1 (column x)
table_s1['class'] = table_s1['class'] - 3
table_s1.columns = ['name', 'description', 'type', 'x', 'miss', '-1',
                    '00', '11', '12', '13', '21', '22', '23', '31', '32', '33']
table_s1['name'] = table_s1['name'].str.upper()

# In[14]: replace missing values by dummies for non-numerical variables

vars_cat = df[liste_var].columns[df_vars.type <= 2]  # var names(non-numeric)
for i in vars_cat:
    df[i] = df[i].astype(int)

# non-numerical variables with missings (the outcome is excluded)
liste = [i for i in vars_cat if (df[i] == -1).any() and i not in var_y]

# one dummy per level (including the level -1 for missings) minus a reference
# level: the first column is dropped for treatment variables, the last column
# for all other variables
dummy_frames = []
for i in liste:
    # the dtype is pinned because pandas >= 2.0 creates boolean dummies by
    # default, which would be written to data.csv as True/False
    dummies = pd.get_dummies(df[i], prefix=i, dtype=np.uint8)
    if i in var_d:
        dummies = dummies.iloc[:, 1:]
    else:
        dummies = dummies.iloc[:, :-1]
    dummy_frames.append(dummies)
liste_dummies = [j for dummies in dummy_frames for j in dummies.columns]
# a single concatenation instead of one copy of df per variable
df = pd.concat([df]+dummy_frames, axis=1)

# no non-numerical variables with more than 2 categories without missings
# in vars_cat (only the dummies 'status' and 'female')

# In[15]: replace missing values by means for numerical variables

vars_cat = df[liste_var].columns[df_vars.type == 3]  # variable names (numeric)
# numeric variables with missings (outcome and treatments are excluded)
liste = [i for i in vars_cat
         if (df[i] == -1).any() and i not in var_y and i not in var_d]

for i in liste:
    # the mean is taken over the observations that are not coded -1
    is_missing = df[i] == -1
    df.loc[is_missing, i] = df.loc[~is_missing, i].mean()

# In[16]: generate dataset for estimation of causal effects

# r: random assignment indicator, d: treatments, y: outcome
r, d, y = var_d[:1], var_d[1:4], var_y


def generate_x(var_x, df_vars, df, liste_dummies, c):
    """
    Create x including dummies and standardize numeric covariates.

    Numeric covariates (type 3) of class c are centered and divided by two
    standard deviations. The standardized values are written back to df,
    i.e., df is modified in place. A covariate without variation (standard
    deviation of zero or undefined) is only centered, because the division
    would turn the whole column into nan's. Non-numeric covariates (type 1
    or 2) enter x directly if they have no missings. For every covariate,
    the dummies in liste_dummies whose names start with the covariate name
    followed by an underscore are appended.

    input arguments var_x: variable names (list)
                    df_vars: statistics of x (dataframe with the columns
                             name, type, class, miss)
                    df: data (dataframe)
                    liste_dummies: variable names of dummies (list)
                    c: variable class (integer, 3/4=x in period 0/1)
    return argument x: names of covariates (list)
    raises IndexError if a name in var_x is not listed in df_vars
    """
    x = []
    for i in var_x:
        # statistics of the covariate (first match); .iloc[0] raises an
        # IndexError for names that are not listed in df_vars
        stats = df_vars.loc[df_vars['name'] == i]
        var_type = stats['type'].iloc[0]
        if (var_type == 3) and (stats['class'].iloc[0] == c):
            # numeric covariate of the requested class
            x.append(i)
            centered = df[i]-df[i].mean()
            scale = 2*df[i].std()
            # the comparison is False for zero and for nan
            df[i] = centered/scale if scale > 0 else centered
        elif (var_type <= 2) and (stats['miss'].iloc[0] == 0):
            # non-numeric covariate without missings
            x.append(i)
        prefix = i+'_'
        x.extend(j for j in liste_dummies if j.startswith(prefix))
    return x


# names of the processed covariates (the numeric covariates in df are
# standardized by these calls)
x0 = generate_x(var_x0, df_vars, df, liste_dummies, 3)
x1 = generate_x(var_x1, df_vars, df, liste_dummies, 4)
liste = r+d+y+x0+x1
# exported columns: assignment indicator, treatments without the first one
# (treat), outcome, and covariates; the names are written in upper case
columns_csv = r+d[1:]+y+x0+x1
df_data_csv = df[columns_csv].rename(columns=str.upper)
df_data_csv.to_csv('data.csv', index=False)

# In[17]: table 4: mean outcome conditional on treatment sequence

# the last row (treat == -1) refers to observations with a missing treatment
# sequence, the column 'missings' to outcomes coded -1
# NOTE(review): the means are taken over all outcome values of a sequence,
# including those coded -1 -- confirm that this is intended
liste = [0, 11, 21, 22, 33, -1]
outcomes = [df.loc[df['treat'] == i, 'emply4'] for i in liste]
table4 = pd.DataFrame(
    {'mean': [round(outcome.mean(), 2) for outcome in outcomes],
     'missings': [(outcome == -1).sum() for outcome in outcomes]},
    index=[0, 11, 21, 22, 33, 'missings'])
# (the index contains no label -1, so the rename leaves the table unchanged)
print('\n table4 \n', table4.rename(index={-1: 'missings'}))

# In[18]: table 5: number of covariates


def _count_raw(var_class, var_type):
    """Count the raw variables of the given class and data type."""
    selected = (df_vars['class'] == var_class) & (df_vars['type'] == var_type)
    return selected.sum()


def _count_processed_dummies(names):
    """Count the processed covariates whose observed values are 0 and 1."""
    return sum(
        1 for name in names
        if df[name].value_counts().sort_index().index.tolist() == [0, 1])


# regressors prior to first treatment: x0
x0_d = _count_raw(3, 1)  # dummies
x0_c = _count_raw(3, 2)  # categorical
x0_n = _count_raw(3, 3)  # numeric
x0_t = x0_d+x0_c+x0_n  # total, raw variables x0
x0_d_processed = _count_processed_dummies(x0)  # dummies, processed x0
x0_n_processed = len(x0) - x0_d_processed  # numeric, processed covariates x0

# regressors prior to second treatment: x1
x1_d = _count_raw(4, 1)  # dummies
x1_c = _count_raw(4, 2)  # categorical
x1_n = _count_raw(4, 3)  # numeric
x1_t = x1_d+x1_c+x1_n  # total, raw variables x1
x1_d_processed = _count_processed_dummies(x1)  # dummies, processed x1
x1_n_processed = len(x1) - x1_d_processed  # numeric, processed covariates x1

row_labels = ['raw variables', 'dummy', 'categorical', 'numeric', 'total',
              'processed variables', 'dummy', 'numeric', 'total']
counts_x0 = ['x0', x0_d, x0_c, x0_n, x0_t, 'x0', x0_d_processed,
             x0_n_processed, len(x0)]
counts_x1 = ['x1', x1_d, x1_c, x1_n, x1_t, 'x1', x1_d_processed,
             x1_n_processed, len(x1)]
table5 = pd.DataFrame(
    {'column1': row_labels, 'column2': counts_x0, 'column3': counts_x1})
print('\n table5 \n', table5)
